# https://leetcode.com/problems/minimize-malware-spread/

import collections
from typing import List

# union find (path halving + union by size), TC:O(N^2*alpha(N)), SC:O(N)
def minMalwareSpread(graph: List[List[int]], initial: List[int]) -> int:
    """Return the node in ``initial`` whose removal minimizes the final malware spread.

    Two nodes ``i < j`` are treated as connected when ``graph[i][j]`` is non-zero
    (only the upper triangle of the matrix is read). Malware reaches every node
    in the connected component of an infected node. Removing an initially
    infected node therefore only helps when it is the *only* infected node in its
    component; in that case the whole component is saved.

    Among such nodes, the one whose component is largest wins, with ties broken
    by the smallest index. If no node is the sole infected node of its component,
    the smallest index in ``initial`` is returned.

    Args:
        graph: n x n adjacency matrix.
        initial: indices of the initially infected nodes; must be non-empty
            (``min`` raises ``ValueError`` otherwise).

    Returns:
        The index from ``initial`` that should be removed.
    """
    n = len(graph)
    parents = list(range(n))
    size = [1] * n  # component size; only meaningful at root indices

    def find(x):
        # Iterative with path halving: a recursive find can exceed Python's
        # recursion limit when unions build a long parent chain (e.g. a path graph).
        while parents[x] != x:
            parents[x] = parents[parents[x]]
            x = parents[x]
        return x

    def union(x, y):
        px, py = find(x), find(y)
        if px == py:
            return
        # Attach the smaller tree under the larger one to keep trees shallow.
        if size[px] > size[py]:
            px, py = py, px
        parents[px] = py
        size[py] += size[px]

    # Merge every connected pair; after this, size[find(i)] is the size of i's component.
    for i in range(n):
        for j in range(i + 1, n):
            if graph[i][j]:
                union(i, j)

    # Number of initially infected nodes per component.
    infected = collections.Counter(find(x) for x in initial)

    best, saved = min(initial), 0
    for x in initial:
        root = find(x)
        if infected[root] != 1:
            continue  # another infected node keeps this component infected anyway
        if size[root] > saved or (size[root] == saved and x < best):
            best, saved = x, size[root]
    return best


# concise union find (path halving + union by size), TC:O(N^2*alpha(N)), SC:O(N)
def minMalwareSpread2(graph: List[List[int]], initial: List[int]) -> int:
    """Return the node in ``initial`` whose removal minimizes the final malware spread.

    Nodes ``i < j`` are connected when ``graph[i][j]`` is non-zero (only the upper
    triangle is read). Removing an initially infected node saves its whole
    component only if it is the sole infected node there; otherwise it saves
    nothing. The node saving the most wins, with ties broken by the smallest index.

    Args:
        graph: n x n adjacency matrix.
        initial: indices of the initially infected nodes; must be non-empty
            (``min`` raises ``ValueError`` otherwise).

    Returns:
        The index from ``initial`` that should be removed.
    """
    n = len(graph)
    parents = list(range(n))
    size = [1] * n  # component size; only meaningful at root indices

    def find(x):
        # Iterative path halving avoids RecursionError on long parent chains.
        while parents[x] != x:
            parents[x] = parents[parents[x]]
            x = parents[x]
        return x

    def union(x, y):
        px, py = find(x), find(y)
        if px == py:
            return
        if size[px] > size[py]:  # attach smaller tree under larger
            px, py = py, px
        parents[px] = py
        size[py] += size[px]

    for i in range(n):
        for j in range(i + 1, n):
            if graph[i][j]:
                union(i, j)

    root = {x: find(x) for x in initial}  # resolve each root once
    infected = collections.Counter(root[x] for x in initial)  # infected nodes per component

    # Sort key: (-nodes saved, index) -> largest saving first, then smallest index.
    return min(
        initial,
        key=lambda x: ((-size[root[x]]) if infected[root[x]] == 1 else 0, x),
    )
